# Copyright 2020 Google Research. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Label map utility functions."""
from absl import logging
from six.moves import range


def _validate_label_map(label_map):
    """Raises if any entry of a label map breaks the id conventions.

    Two rules are enforced for every entry in ``label_map.item``:
      * ids must be non-negative;
      * id 0 may only be used by an entry whose ``name`` or ``display_name``
        is exactly "background".

    Args:
      label_map: StringIntLabelMap to validate.

    Raises:
      ValueError: if label map is invalid.
    """
    for entry in label_map.item:
        if entry.id < 0:
            raise ValueError("Label map ids should be >= 0.")
        if entry.id != 0:
            # Only id 0 carries the extra background restriction.
            continue
        if entry.name == "background" or entry.display_name == "background":
            continue
        raise ValueError("Label map id 0 is reserved for the background label")


def create_category_index(categories):
    """Indexes COCO-style category dicts by their 'id' field.

    Args:
      categories: an iterable of dicts, each of which has at least:
        'id': (required) an integer id uniquely identifying this category.
        'name': (required) string representing category name
          e.g., 'cat', 'dog', 'pizza'.

    Returns:
      category_index: a dict mapping each category's 'id' to the category dict
        itself (not a copy). If two categories share an id, the later one wins.
    """
    return {category["id"]: category for category in categories}


def get_max_label_map_index(label_map):
    """Returns the largest id among the entries of a label map.

    Args:
      label_map: a StringIntLabelMapProto

    Returns:
      an integer, the maximum ``id`` over ``label_map.item``.

    Raises:
      ValueError: if ``label_map.item`` is empty (propagated from ``max``).
    """
    entry_ids = (entry.id for entry in label_map.item)
    return max(entry_ids)


def convert_label_map_to_categories(label_map, max_num_classes, use_display_name=True):
    """Given label map proto returns categories list compatible with eval.

    This function converts a label map proto into a list of dicts, each of
    which has the following keys:
      'id': (required) an integer id uniquely identifying this category.
      'name': (required) string representing category name
        e.g., 'cat', 'dog', 'pizza'.
      'keypoints': (optional) a dictionary of keypoint string 'label' to integer
        'id'. Present only when the label map item lists keypoints.
    An item is kept only if 0 < id <= max_num_classes (equivalently, id - 1 lies
    in [0, max_num_classes)); other items are logged and skipped.
    If there are several items mapping to the same id in the label map,
    only the first one is kept in the categories list.

    Args:
      label_map: a StringIntLabelMapProto or None.  If None (or otherwise falsy),
        a default categories list is created with max_num_classes categories
        named 'category_1' .. 'category_<max_num_classes>'.
      max_num_classes: maximum number of (consecutive) label indices to include.
      use_display_name: (boolean) choose whether to load 'display_name' field as
        category name.  If False or if the display_name field does not exist, uses
        'name' field as category names instead.

    Returns:
      categories: a list of dictionaries representing all possible categories.

    Raises:
      ValueError: if an item lists the same keypoint id more than once.
    """
    if not label_map:
        # No label map: synthesize placeholder categories with 1-based ids.
        label_id_offset = 1
        return [
            {
                "id": class_id + label_id_offset,
                "name": "category_{}".format(class_id + label_id_offset),
            }
            for class_id in range(max_num_classes)
        ]

    categories = []
    # A set keeps the duplicate-id check O(1) per item instead of a list scan.
    ids_already_added = set()
    for item in label_map.item:
        if not 0 < item.id <= max_num_classes:
            logging.info(
                "Ignore item %d since it falls outside of requested label range.",
                item.id,
            )
            continue
        if item.id in ids_already_added:
            # The first item with a given id wins; later duplicates are dropped.
            continue
        ids_already_added.add(item.id)
        if use_display_name and item.HasField("display_name"):
            name = item.display_name
        else:
            name = item.name
        category = {"id": item.id, "name": name}
        if item.keypoints:
            category["keypoints"] = _keypoints_to_dict(item.keypoints)
        categories.append(category)
    return categories


def _keypoints_to_dict(keypoints):
    """Maps each keypoint's label to its id, rejecting duplicate keypoint ids.

    Args:
      keypoints: an iterable of objects with `label` and `id` attributes.

    Returns:
      A dict from keypoint label to keypoint id. If two keypoints share a
      label (but not an id), the later one wins.

    Raises:
      ValueError: if the same keypoint id appears more than once.
    """
    label_to_id = {}
    seen_ids = set()
    for kv in keypoints:
        if kv.id in seen_ids:
            raise ValueError(
                "Duplicate keypoint ids are not allowed. "
                "Found {} more than once".format(kv.id)
            )
        label_to_id[kv.label] = kv.id
        seen_ids.add(kv.id)
    return label_to_id


def create_class_agnostic_category_index():
    """Returns a fresh category index holding a single id-1 `object` class."""
    object_category = {"id": 1, "name": "object"}
    return {object_category["id"]: object_category}
